from __future__ import annotations
from functools import reduce
from pathlib import Path
from typing import Any, Callable, Dict, List, Type, Union, cast
import json
from typing_extensions import TypeGuard
from nacolla.models import ImmutableModel
from nacolla.operations.merge import merge
from nacolla.parsing.implementation_map_file_specification import (
    KWARG_TYPE,
    ImplementationMapSpecification,
)
from nacolla.parsing.parse_implementation import IMPLEMENTATION, parse_implementation
from nacolla.stateful_callable import StatefulCallable, make_step
from nacolla.step import Step

# Mapping from each step name declared in an implementation map file to the
# Step built for it (see parse_implementation_map, which enforces that the
# names are unique).
IMPLEMENTATION_MAP = Dict[str, Step[ImmutableModel, ImmutableModel]]


def parse_implementation_map(
    implementation_map_file: Path,
) -> IMPLEMENTATION_MAP:
    """Parse an implementation map JSON file into a mapping of named steps.

    The file contents are loaded into ``ImplementationMapSpecification``.
    Each declared implementation is imported with ``parse_implementation``
    and turned into a step:

    - If the import yields a dict of callables, one step is built per
      callable and the steps are combined with ``merge``.
    - Otherwise, the single callable becomes one step.

    Args:
        implementation_map_file: Path to the JSON implementation map.

    Returns:
        Mapping from each declared step name to its built step.

    Raises:
        ValueError: If two implementations share a name, if an implementation
            resolves to an empty dict of callables, or if a step cannot be
            built (raised by ``_parse_step``).
    """
    # JSON text is UTF-8 by specification; don't depend on the platform's
    # default locale encoding.
    specification: ImplementationMapSpecification = ImplementationMapSpecification(
        **json.loads(implementation_map_file.read_text(encoding="utf-8"))
    )

    implementation_dict: IMPLEMENTATION_MAP = {}
    for step in specification.implementations:
        if step.name in implementation_dict:
            raise ValueError(
                f"Steps must have unique names, encountered '{step.name}' twice"
            )
        parsed_implementation: IMPLEMENTATION = parse_implementation(
            import_definition=step.callable
        )

        if isinstance(parsed_implementation, dict):
            kwargs = cast(
                Dict[str, KWARG_TYPE], step.kwargs
            )  # validated in pydantic specification
            step_accumulator: List[Step[ImmutableModel, ImmutableModel]] = [
                _parse_step(callable_name, callable_func, kwargs=kwargs)
                for callable_name, callable_func in parsed_implementation.items()
            ]
            if not step_accumulator:
                # reduce() on an empty list would fail with an opaque TypeError.
                raise ValueError(
                    f"Implementation '{step.name}' does not define any callables"
                )
            implementation_dict[step.name] = reduce(merge, step_accumulator)
        else:
            # NOTE(review): step.kwargs is not forwarded for single-callable
            # implementations; confirm this is intended.
            implementation_dict[step.name] = _parse_step(
                step.name, parsed_implementation
            )

    return implementation_dict


def _parse_step(
    callable_name: str,
    callable_func: Union[
        Callable[[ImmutableModel], ImmutableModel],
        Type[StatefulCallable[ImmutableModel, ImmutableModel]],
    ],
    kwargs: Union[Dict[str, Any], None] = None,
) -> Step[ImmutableModel, ImmutableModel]:
    """Build a single step from an imported callable.

    Args:
        callable_name: Name given to the resulting step. It is also the key
            looked up in ``kwargs``.
        callable_func: Either a ``StatefulCallable`` subclass, which is
            instantiated and passed to ``make_step``, or a plain callable,
            which is wrapped directly in a ``Step``.
        kwargs: Optional mapping from callable name to the keyword arguments
            used to instantiate that callable. Only the entry for
            ``callable_name`` is used. Defaults to no kwargs.

    Returns:
        The constructed step.

    Raises:
        ValueError: If kwargs are given for a callable that is not a
            ``StatefulCallable`` subclass, or if ``callable_func`` is neither a
            ``StatefulCallable`` subclass nor callable.
    """
    # None instead of a shared mutable ``{}`` default.
    step_kwargs = (kwargs or {}).get(callable_name)

    if step_kwargs:
        if _is_stateful_callable(callable_func):
            # Forward only this callable's own entry. Passing the whole
            # name-keyed mapping would hand every sibling's kwargs to the
            # constructor. Assumes the entry is a mapping of keyword args.
            return make_step(callable_func(**step_kwargs), name=callable_name)
        raise ValueError(
            f"Cannot pass kwargs to '{callable_name}'.\n"
            "kwargs can only be passed to stateful callables"
        )

    if _is_stateful_callable(callable_func):
        return make_step(callable_func(), name=callable_name)

    if _is_callable(to_check=callable_func):
        return Step[ImmutableModel, ImmutableModel](
            name=callable_name, apply=callable_func
        )

    raise ValueError(
        f"Could not build step from '{callable_func}' got from '{callable_name}'"
    )


def _is_stateful_callable(
    to_check: Any,
) -> TypeGuard[Type[StatefulCallable[ImmutableModel, ImmutableModel]]]:
    return issubclass(to_check, StatefulCallable)


def _is_callable(
    to_check: Any,
) -> TypeGuard[Callable[[ImmutableModel], ImmutableModel]]:
    return callable(to_check)
